{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from embedding.vit4embedding import ViTForEmbedding,VitConfigForEmbedding\n",
    "from transformers import AutoImageProcessor\n",
    "from PIL import Image\n",
    "import numpy as np \n",
    "import torch as t "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForEmbedding were not initialized from the model checkpoint at models/google/vit-large-patch16-224-in21k and are newly initialized: ['embeddings.mask_token', 'output_linear.bias', 'output_linear.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.\n"
     ]
    }
   ],
   "source": [
    "# Checkpoint directory that provides the pretrained ViT weights.\n",
    "model_name_or_path = 'models/google/vit-large-patch16-224-in21k'\n",
    "\n",
    "# Load the checkpoint's config, passing embedding_dim=1024 on top of it.\n",
    "config = VitConfigForEmbedding.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    embedding_dim=1024,\n",
    ")\n",
    "\n",
    "# The checkpoint has no weights for `output_linear` or `embeddings.mask_token`,\n",
    "# so those parameters are newly initialised (see the warning this cell prints).\n",
    "image_embedding_model = ViTForEmbedding.from_pretrained(\n",
    "    pretrained_model_name_or_path=model_name_or_path,\n",
    "    config=config,\n",
    "    device_map='cuda:0',\n",
    ")\n",
    "image_processor = AutoImageProcessor.from_pretrained(model_name_or_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample input for the sanity check: one screenshot, converted to RGB.\n",
    "image_paths = [\n",
    "    'data/test_data/截屏2024-11-21 13.20.05.png'\n",
    "]\n",
    "images_list = []\n",
    "for image_path in image_paths:\n",
    "    # `convert` returns a new image, so the file handle can be released\n",
    "    # right away instead of staying open until garbage collection.\n",
    "    with Image.open(image_path) as image_file:\n",
    "        images_list.append(image_file.convert('RGB'))\n",
    "\n",
    "# Preprocess into PyTorch tensors and move the whole batch onto the model's device.\n",
    "encoded_input = image_processor(images=images_list, return_tensors=\"pt\").to(\n",
    "    image_embedding_model.device\n",
    ")\n",
    "\n",
    "# Inference only: skip autograd bookkeeping so no graph is retained.\n",
    "with t.no_grad():\n",
    "    last_hidden_states, pool_output = image_embedding_model(**encoded_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1024])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expected torch.Size([1, 1024]): one input image, embedding_dim=1024 in the config.\n",
    "pool_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0101,  0.2668, -0.0197,  ...,  0.1814,  0.1021,  0.0779]],\n",
       "       device='cuda:0', grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the values of the model's second output (the pooled embedding).\n",
    "pool_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-12-03 07:04:38,754] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['models/test/image_embedding_001/preprocessor_config.json']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip test, step 1: write the model and the image processor\n",
    "# into a test directory so they can be loaded back afterwards.\n",
    "test_save_dir = \"models/test/image_embedding_001\"\n",
    "image_embedding_model.save_pretrained(test_save_dir)\n",
    "image_processor.save_pretrained(test_save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['models/image_embedding/image_embedding_001/preprocessor_config.json']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save the model (and its image processor) so it can be used for training later.\n",
    "final_image_embedding = \"models/image_embedding/image_embedding_001\"\n",
    "image_embedding_model.save_pretrained(final_image_embedding)\n",
    "image_processor.save_pretrained(final_image_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip test, step 2: load the model back from the test directory.\n",
    "# No `config` is passed here, unlike the initial load -- presumably it is\n",
    "# read from the files written by save_pretrained (TODO confirm).\n",
    "image_embedding_model2 = ViTForEmbedding.from_pretrained(\n",
    "    pretrained_model_name_or_path='models/test/image_embedding_001',\n",
    "    device_map='cuda:0',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1024])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the same batch through the reloaded model.\n",
    "# NOTE: this rebinds `last_hidden_states`, so the value produced by the\n",
    "# first model is no longer reachable after this cell.\n",
    "# Inference only: skip autograd bookkeeping so no graph is retained.\n",
    "with t.no_grad():\n",
    "    last_hidden_states, pool_output2 = image_embedding_model2(**encoded_input)\n",
    "pool_output2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0101,  0.2668, -0.0197,  ...,  0.1814,  0.1021,  0.0779]],\n",
       "       device='cuda:0', grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the reloaded model's embedding, for comparison with `pool_output`.\n",
    "pool_output2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip test, step 3: the reloaded model must reproduce the embedding\n",
    "# computed by the original model for the same input.\n",
    "embeddings_match = t.allclose(pool_output, pool_output2)\n",
    "# Fail loudly on a full re-run instead of relying on someone reading the output.\n",
    "assert embeddings_match, 'reloaded model does not reproduce the original embedding'\n",
    "embeddings_match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.6441,  0.0309,  0.8128,  ..., -0.0981, -0.1634, -0.0594]],\n",
       "       device='cuda:0', grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise sum of the hidden states with the leading `hidden_size`\n",
    "# columns of the reloaded model's embedding.\n",
    "hidden_size = image_embedding_model2.config.hidden_size\n",
    "last_hidden_states + pool_output2[:, :hidden_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.6542, -0.2359,  0.8325,  ..., -0.2795, -0.2655, -0.1374]],\n",
       "       device='cuda:0', grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hz_net_v2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
